import argparse
from mkd_model import MKD_Trainer

def _positive_int(text):
    """Convert a CLI string to an int, rejecting values <= 0."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def _positive_float(text):
    """Convert a CLI string to a float, rejecting values <= 0 and NaN."""
    value = float(text)
    # `not value > 0` (rather than `value <= 0`) also rejects NaN.
    if not value > 0:
        raise argparse.ArgumentTypeError(f"expected a positive number, got {text!r}")
    return value


def parse_args(argv=None):
    """Parse the command-line options for MKD training.

    Args:
        argv: Sequence of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, matching the original
            no-argument behaviour; pass a list to call this programmatically.

    Returns:
        argparse.Namespace with attributes ``dataset``, ``model``, ``epochs``,
        ``batch_size`` and ``lr``.

    Raises:
        SystemExit: raised by argparse on unknown options, choices outside the
            allowed sets, or non-positive epochs / batch size / learning rate.
    """
    # Help strings are intentionally left as-is (user-facing CLI text):
    # dataset name, model type, number of epochs, batch size, learning rate.
    parser = argparse.ArgumentParser(description="MKD训练入口")
    parser.add_argument("--dataset", type=str, default="cifar100", choices=["cifar100", "texas100", "purchase100"], help="数据集名称")
    parser.add_argument("--model", type=str, default="resnet18", choices=["resnet18", "alexnet"], help="模型类型")
    # Reject non-positive values at parse time instead of failing deep inside training.
    parser.add_argument("--epochs", type=_positive_int, default=50, help="训练轮次")
    parser.add_argument("--batch_size", type=_positive_int, default=512, help="批次大小")
    parser.add_argument("--lr", type=_positive_float, default=1e-3, help="学习率")
    return parser.parse_args(argv)

if __name__ == "__main__":
    # Read the CLI options, build the MKD trainer from them, and start training.
    cli_opts = parse_args()
    trainer_kwargs = dict(
        dataset_name=cli_opts.dataset,
        model_type=cli_opts.model,
        epochs=cli_opts.epochs,
        batch_size=cli_opts.batch_size,
        lr=cli_opts.lr,
    )
    MKD_Trainer(**trainer_kwargs).run_mkd()